/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.test.web.servlet.result;

import org.hamcrest.Matcher;

import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.ResultMatcher;
import org.springframework.ui.ModelMap;
import org.springframework.validation.BindingResult;
import org.springframework.validation.Errors;
import org.springframework.validation.FieldError;
import org.springframework.web.servlet.ModelAndView;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.springframework.test.util.AssertionErrors.assertEquals;
import static org.springframework.test.util.AssertionErrors.assertTrue;
import static org.springframework.test.util.AssertionErrors.fail;

/**
 * Factory for assertions on the model.
 *
 * <p>An instance of this class is typically accessed via
 * {@link MockMvcResultMatchers#model}.
 *
 * @author Rossen Stoyanchev
 * @since 3.2
 */
public class ModelResultMatchers {

    /**
     * Protected constructor.
     * Use {@link MockMvcResultMatchers#model()}.
     */
    protected ModelResultMatchers() {
    }


    /**
     * Assert a model attribute value with the given Hamcrest {@link Matcher}.
     */
    @SuppressWarnings("unchecked")
    public <T> ResultMatcher attribute(final String name, final Matcher<T> matcher) {
        return result -> {
            ModelAndView mav = getModelAndView(result);
            assertThat("Model attribute '" + name + "'", (T) mav.getModel().get(name), matcher);
        };
    }

    /**
     * Assert a model attribute value.
     */
    public ResultMatcher attribute(final String name, final Object value) {
        return result -> {
            ModelAndView mav = getModelAndView(result);
            assertEquals("Model attribute '" + name + "'", value, mav.getModel().get(name));
        };
    }

    /**
     * Assert the given model attributes exist.
     */
    public ResultMatcher attributeExists(final String... names) {
        return result -> {
            ModelAndView mav = getModelAndView(result);
            for (String name : names) {
                assertTrue("Model attribute '" + name + "' does not exist", mav.getModel().get(name) != null);
            }
        };
    }

    /**
     * Assert the given model attributes do not exist.
     */
    public ResultMatcher attributeDoesNotExist(final String... names) {
        return result -> {
            ModelAndView mav = getModelAndView(result);
            for (String name : names) {
                assertTrue("Model attribute '" + name + "' exists", mav.getModel().get(name) == null);
            }
        };
    }

    /**
     * Assert the given model attribute(s) have errors.
     */
    public ResultMatcher attributeErrorCount(final String name, final int expectedCount) {
        return result -> {
            ModelAndView mav = getModelAndView(result);
            Errors errors = getBindingResult(mav, name);
            assertEquals("Binding/validation error count for attribute '" + name + "', ",
                    expectedCount, errors.getErrorCount());
        };
    }

    /**
     * Assert the given model attribute(s) have errors.
     */
    public ResultMatcher attributeHasErrors(final String... names) {
        return mvcResult -> {
            ModelAndView mav = getModelAndView(mvcResult);
            for (String name : names) {
                BindingResult result = getBindingResult(mav, name);
                assertTrue("No errors for attribute '" + name + "'", result.hasErrors());
            }
        };
    }

    /**
     * Assert the given model attribute(s) do not have errors.
     */
    public ResultMatcher attributeHasNoErrors(final String... names) {
        return mvcResult -> {
            ModelAndView mav = getModelAndView(mvcResult);
            for (String name : names) {
                BindingResult result = getBindingResult(mav, name);
                assertTrue("Unexpected errors for attribute '" + name + "': " + result.getAllErrors(),
                        !result.hasErrors());
            }
        };
    }

    /**
     * Assert the given model attribute field(s) have errors.
     */
    public ResultMatcher attributeHasFieldErrors(final String name, final String... fieldNames) {
        return mvcResult -> {
            ModelAndView mav = getModelAndView(mvcResult);
            BindingResult result = getBindingResult(mav, name);
            assertTrue("No errors for attribute '" + name + "'", result.hasErrors());
            for (final String fieldName : fieldNames) {
                boolean hasFieldErrors = result.hasFieldErrors(fieldName);
                assertTrue("No errors for field '" + fieldName + "' of attribute '" + name + "'", hasFieldErrors);
            }
        };
    }

    /**
     * Assert a field error code for a model attribute using exact String match.
     *
     * @since 4.1
     */
    public ResultMatcher attributeHasFieldErrorCode(final String name, final String fieldName, final String error) {
        return mvcResult -> {
            ModelAndView mav = getModelAndView(mvcResult);
            BindingResult result = getBindingResult(mav, name);
            assertTrue("No errors for attribute '" + name + "'", result.hasErrors());
            FieldError fieldError = result.getFieldError(fieldName);
            if (fieldError == null) {
                fail("No errors for field '" + fieldName + "' of attribute '" + name + "'");
            }
            String code = fieldError.getCode();
            assertTrue("Expected error code '" + error + "' but got '" + code + "'", error.equals(code));
        };
    }

    /**
     * Assert a field error code for a model attribute using a {@link org.hamcrest.Matcher}.
     *
     * @since 4.1
     */
    public <T> ResultMatcher attributeHasFieldErrorCode(final String name, final String fieldName,
                                                        final Matcher<? super String> matcher) {

        return mvcResult -> {
            ModelAndView mav = getModelAndView(mvcResult);
            BindingResult result = getBindingResult(mav, name);
            assertTrue("No errors for attribute: [" + name + "]", result.hasErrors());
            FieldError fieldError = result.getFieldError(fieldName);
            if (fieldError == null) {
                fail("No errors for field '" + fieldName + "' of attribute '" + name + "'");
            }
            String code = fieldError.getCode();
            assertThat("Field name '" + fieldName + "' of attribute '" + name + "'", code, matcher);
        };
    }

    /**
     * Assert the total number of errors in the model.
     */
    public <T> ResultMatcher errorCount(final int expectedCount) {
        return result -> {
            int actualCount = getErrorCount(getModelAndView(result).getModelMap());
            assertEquals("Binding/validation error count", expectedCount, actualCount);
        };
    }

    /**
     * Assert the model has errors.
     */
    public <T> ResultMatcher hasErrors() {
        return result -> {
            int count = getErrorCount(getModelAndView(result).getModelMap());
            assertTrue("Expected binding/validation errors", count != 0);
        };
    }

    /**
     * Assert the model has no errors.
     */
    public <T> ResultMatcher hasNoErrors() {
        return result -> {
            ModelAndView mav = getModelAndView(result);
            for (Object value : mav.getModel().values()) {
                if (value instanceof Errors) {
                    assertTrue("Unexpected binding/validation errors: " + value, !((Errors) value).hasErrors());
                }
            }
        };
    }

    /**
     * Assert the number of model attributes.
     */
    public <T> ResultMatcher size(final int size) {
        return result -> {
            ModelAndView mav = getModelAndView(result);
            int actual = 0;
            for (String key : mav.getModel().keySet()) {
                if (!key.startsWith(BindingResult.MODEL_KEY_PREFIX)) {
                    actual++;
                }
            }
            assertEquals("Model size", size, actual);
        };
    }

    private ModelAndView getModelAndView(MvcResult mvcResult) {
        ModelAndView mav = mvcResult.getModelAndView();
        if (mav == null) {
            fail("No ModelAndView found");
        }
        return mav;
    }

    private BindingResult getBindingResult(ModelAndView mav, String name) {
        BindingResult result = (BindingResult) mav.getModel().get(BindingResult.MODEL_KEY_PREFIX + name);
        if (result == null) {
            fail("No BindingResult for attribute: " + name);
        }
        return result;
    }

    private int getErrorCount(ModelMap model) {
        int count = 0;
        for (Object value : model.values()) {
            if (value instanceof Errors) {
                count += ((Errors) value).getErrorCount();
            }
        }
        return count;
    }

}
